
[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True #2677

Open

sudhakarsingh27 wants to merge 15 commits into NVIDIA:main from sudhakarsingh27:fix_return_stats_max_cudnn

Conversation

@sudhakarsingh27 (Collaborator)

Description

cuDNN recently made it possible to return any subset of {Stats, SumExp, Max}. This PR adapts TE to always get Stats from cuDNN, and the Max tensor as well when return_max_logit=True. (Note that Stats = log(SumExp) + Max.)
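As a quick sanity check of that identity (an illustration, not code from this PR): Stats is just a numerically stable log-sum-exp of the attention scores.

import torch

# x stands in for one row of attention scores (x = Q @ K.T in the real kernel)
x = torch.randn(8)

row_max = x.max()                       # Max
sum_exp = torch.exp(x - row_max).sum()  # SumExp, computed stably
stats = torch.log(sum_exp) + row_max    # Stats = log(SumExp) + Max

# Equivalent to the direct log-sum-exp of the scores
assert torch.allclose(stats, torch.logsumexp(x, dim=0))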

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • fused_attn_f16_arbitrary_seqlen.cu
    • Removed references to the SumExp tensor, since cuDNN now returns Stats by default and SumExp is no longer needed.
    • Set generate_stats=True, which forces cuDNN to always return the Stats tensor (needed in the backward pass).
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py
    • Removed the code that manually computed Stats = log(SumExp) + Max, since cuDNN returns Stats directly and TE doesn't need SumExp from cuDNN (see the sketch after this list).
  • Updated the corresponding documentation.
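A minimal before/after sketch of that Python-layer change, based on the snippets quoted in the review threads below (tensor ordering and names here are illustrative, not TE's exact code):

# Before: with return_max_logit=True, cuDNN returned Max and SumExp,
# and TE reconstructed Stats by hand:
#   output_tensors = [out, Max, SumExp]  (assumed ordering)
stats = output_tensors[1] + torch.log(output_tensors[2])  # Max + log(SumExp)

# After: cuDNN always returns Stats directly, plus Max when requested:
#   output_tensors = [out, Stats] or [out, Stats, Max]
stats = output_tensors[1]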

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 5 commits February 12, 2026 13:12
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@greptile-apps (Contributor) commented Feb 12, 2026

Greptile Summary

This PR adapts TransformerEngine to leverage cuDNN's recent enhancement that allows returning any subset of {Stats, SumExp, Max}. The implementation now always retrieves Stats from cuDNN (which equals log(SumExp)+Max) and additionally obtains the Max tensor when return_max_logit=True.

Key changes:

  • Set generate_stats=true to force cuDNN to always return the Stats tensor needed for backward pass
  • Removed all references to SumExp tensor since it's no longer needed (Stats is returned directly)
  • Eliminated manual computation of Stats = log(SumExp) + Max in Python layer
  • Renamed generate_max_sum_exp to return_max_logit in FADescriptor_v1 for clarity
  • Updated documentation across all API functions to reflect new semantics
  • Adjusted tensor ordering in output: now returns Stats first, then Max (when requested)

All previous review feedback has been addressed, including updated comments in fused_attn.py and documentation in fused_attn.h.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - changes are well-structured and all previous feedback has been addressed
  • The implementation correctly adapts to cuDNN's new capability, maintains consistency across all layers (C++, Python bindings, documentation), and addresses all previous review comments. The logic flow is clean: Stats is always generated, Max is conditionally added when requested. Tensor ordering and field renaming are consistent throughout the codebase.
  • No files require special attention

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: Core implementation updated to always generate Stats from cuDNN and optionally Max when return_max_logit=true; removed SumExp references and logic.
  • transformer_engine/common/fused_attn/utils.h: Renamed the generate_max_sum_exp field to return_max_logit in the FADescriptor_v1 struct for clarity and consistency.
  • transformer_engine/common/include/transformer_engine/fused_attn.h: Updated documentation for the return_max_logit parameter to reflect the new behavior (produces Max along with Stats rather than choosing between Max/SumExp and Stats).
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py: Removed the manual Stats calculation (log(SumExp) + Max); now directly uses Stats from cuDNN; updated comments to reflect the new tensor order (Stats, then Max).
  • transformer_engine/pytorch/csrc/extensions/attention.cpp: Updated comments to reflect that return_max_logit=true now returns Stats and Max instead of Max and SumExp.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Fused Attention Forward Pass] --> B{generate_stats = true}
    B --> C[cuDNN SDPA Operation]
    C --> D[Always returns Stats tensor]
    D --> E{return_max_logit?}
    E -->|true| F[Also set Max output]
    E -->|false| G[Max = nullptr]
    F --> H[Output: O, Stats, Max]
    G --> I[Output: O, Stats]
    H --> J[Python Layer: aux_ctx_tensors = Stats]
    I --> K[Python Layer: aux_ctx_tensors = Stats]
    J --> L[Compute max_logit from Max tensor]
    L --> M[Return: O, aux_ctx_tensors, max_logit]
    K --> N[Return: O, aux_ctx_tensors]
    
    style D fill:#90EE90
    style F fill:#FFD700
    style H fill:#87CEEB
    style M fill:#87CEEB

Last reviewed commit: 56e46fd

@greptile-apps bot reviewed: 3 files reviewed, 1 comment

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn
@greptile-apps bot reviewed: 3 files reviewed, no comments

@greptile-apps bot reviewed: 3 files reviewed, no comments

@greptile-apps bot reviewed: 3 files reviewed, 1 comment

@greptile-apps (Contributor) commented Feb 17, 2026

Additional Comments (1)

transformer_engine/pytorch/cpp_extensions/fused_attn.py
Stale docstring: wrong formula for softmaxStats

The public docstring still describes softmaxStats as log(sum(e^(x - max(x)))), which is log(SumExp). However, with this PR, the returned tensor is cuDNN's Stats = log(SumExp) + Max, not just log(SumExp). This formula was already incorrect before this PR (the old code computed Max + log(SumExp) and stored it as stats), but the PR is an opportunity to correct it.

                       softmaxStats: torch.Tensor
                           log(sum(e^(x - max(x)))) + max(x), where x=Q*K.T (i.e. Stats = log(SumExp) + Max)
                           shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32

@greptile-apps bot reviewed: 3 files reviewed, 1 comment

stats = output_tensors[1] + torch.log(output_tensors[2])
# thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Collaborator commented:

Do we need the "there's no typo here" :)

sudhakarsingh27 (Collaborator, Author) replied:

I deliberately added it because I didn't believe it and checked the shapes myself :P

size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
  const auto cudnn_runtime_version = cudnnGetVersion();

@cyanguwa (Collaborator) commented Feb 18, 2026:

You might need to make these changes in the "Aux_CTX_Tensors->size == 0" sections in the _fwd/bwd_qkvpacked/kvpacked APIs as well. Please check. Thanks!

sudhakarsingh27 (Collaborator, Author) replied:

Looks like I don't need to, because nvte_fused...qvpacked are in fused_attn.cpp, which calls fused_attn_f16_arbitrary... just like the regular nvte_fused_fwd/bwd.

@greptile-apps bot reviewed: 3 files reviewed, 1 comment

# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
@KshitijLakhani (Collaborator) commented Feb 19, 2026:

Maybe I understood this incorrectly, but isn't TE now also supposed to receive Max from cuDNN directly (like Stats, except that Stats is always returned while Max can be toggled), rather than calling amax() in TE?

(Sudhakar: Why am I able to update your comment?)

@sudhakarsingh27 (Collaborator, Author) replied Feb 20, 2026:

cuDNN returns Max ([b, h, sq, 1]), so it's an additional softmax statistic (apparently the subset {Stats, Max} is enough for cuDNN bwd, rather than the full set {Stats, SumExp, Max}).

Further, for muon, we need to do amax on it to get a tensor of shape [h]. return_max_logit in TE controls whether to fetch Max from cuDNN.

Perf-wise, it'd be nice for cuDNN to do the additional reduction and return the [h]-shaped tensor for muon as well, but that's not in the scope of this PR.

(Kshitij: looks like I can as well)
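A hedged sketch of that reduction for the bshd layout (amax_dims = (0, 2, 3) is an assumption here; the actual values live in fused_attn.py):

import torch

# Max from cuDNN has shape [b, h, sq, 1] for bshd, per the thread above
b, h, sq = 2, 16, 128
max_tensor = torch.randn(b, h, sq, 1)

# Reduce over batch, sequence, and the trailing singleton dim,
# leaving one max logit per head: shape [h], as muon expects
amax_dims = (0, 2, 3)
max_logit = torch.amax(max_tensor, dim=amax_dims)
assert max_logit.shape == (h,)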

@greptile-apps bot reviewed: 3 files reviewed, no comments

@greptile-apps bot reviewed: 5 files reviewed, 1 comment

@sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from 21ca43a to becc3ad on February 20, 2026 at 19:41
@greptile-apps bot reviewed: 5 files reviewed, 1 comment

@greptile-apps (Contributor) commented Feb 20, 2026

Additional Comments (1)

transformer_engine/common/include/transformer_engine/fused_attn.h
Entire file has been reformatted with unintentional 3-space indentation changes. This creates a large diff unrelated to the actual feature changes. Revert the formatting to match the original file structure.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from d4568db to 8f40cab on February 20, 2026 at 20:00
@greptile-apps bot reviewed: 5 files reviewed, no comments
